Waste Sorting using fast.ai v2¶

Automatic Waste Sorting using AI

  • author: Amay Trivedi https://github.com/amay1212/Waste-Sorting/blob/master/Waste_Sorter%20Extended.ipynb

This notebook is an extension of Waste Sorter by Collindching; several code snippets and cells are adapted, with some tweaks, from the link above.

The primary focus of this notebook is to improve the model's performance by increasing accuracy and reducing the misclassification errors noticed earlier.

This notebook is built on the fast.ai v2 library; more information can be found at https://course.fast.ai/. Collindching's notebook targets fast.ai v1, so some tweaks have been made here to support v2.

Why waste sorting?¶

Recycling contamination occurs when waste is incorrectly disposed of - like recycling a pizza box with oil on it (compost). Or when waste is correctly disposed of but incorrectly prepared - like recycling unrinsed jam jars.

Contamination is a huge problem in the recycling industry that can be mitigated with automated waste sorting. Just for kicks, I thought I'd try my hand at prototyping an image classifier to classify trash and recyclables - this classifier could have applications in an optical sorting system.

Waste Classifier Model Pipeline¶

In this project, I'll try to reduce the misclassification error noticed earlier (see the link above).

We will follow the same steps as before:

  1. Extract data
  2. Investigate why so many misclassifications happened in the first place
  3. Model the data
  4. Make predictions on new images
  5. Compare our results to the previous version of the notebook
  6. Further enhancements and research

The code cell below simply upgrades the fastai library, to ensure we have the latest fixes in one place. While plotting the top-loss images I noticed a few empty plots, and upgrading was the quickest fix I could find, so to be on the safe side this additional step is performed.

In [ ]:
import warnings
warnings.filterwarnings('ignore')
!pip install --upgrade git+https://github.com/fastai/fastai.git
Collecting git+https://github.com/fastai/fastai.git
  Cloning https://github.com/fastai/fastai.git to /tmp/pip-req-build-s_v1f852
  Running command git clone -q https://github.com/fastai/fastai.git /tmp/pip-req-build-s_v1f852
Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (from fastai==2.5.4) (21.1.3)
  ...
In [ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [ ]:
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

from fastbook import *
from fastai.vision.widgets import *
from pathlib import Path
from glob2 import glob
from sklearn.metrics import confusion_matrix

import pandas as pd
import numpy as np
import os
import zipfile as zf
import shutil
import re
import seaborn as sns

1. Extract data¶

First, we need to extract the contents of "dataset-resized.zip".

In [ ]:
# Alternatively, mount Google Drive to access the dataset..
from google.colab import drive
drive.mount('/content/gdrive')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
In [ ]:
# Copying and extracting the dataset fetched from the site..
!cp -r gdrive/MyDrive/dataset-resized.zip .
files = zf.ZipFile("dataset-resized.zip",'r')
files.extractall()
files.close()

Once unzipped, the dataset-resized folder has six subfolders:

In [ ]:
# A basic sanity check to ensure the waste category folders are in place.
os.listdir(os.path.join(os.getcwd(),"dataset-resized"))
Out[ ]:
['plastic', 'trash', 'cardboard', 'metal', 'glass', 'paper', '.DS_Store']

2. Organize images into different folders¶

Now that we've extracted the data, I'm going to split the images up into train, validation, and test folders with a 50-25-25 split. First, I'll define some functions that will help me quickly build it. If you're not interested in how the data set is built, you can just run this cell and ignore the details.
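As a quick sanity check of the 50-25-25 scheme before building the folders, the split logic can be sketched in isolation (the `split_counts` helper is hypothetical, introduced here only for illustration):

```python
import random

def split_counts(n, train_frac=0.5):
    """Sketch of the 50-25-25 split: half the indices go to train,
    and the remainder is split evenly between valid and test."""
    full_set = list(range(1, n + 1))
    random.seed(1)
    train = random.sample(full_set, int(train_frac * n))
    remain = sorted(set(full_set) - set(train))
    random.seed(1)
    valid = random.sample(remain, int(.5 * len(remain)))
    test = sorted(set(remain) - set(valid))
    return len(train), len(valid), len(test)

print(split_counts(100))  # -> (50, 25, 25)
```

Whatever the folder size, the three parts always cover the full index range exactly once, which is what the set differences guarantee.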

In [ ]:
?random
In [ ]:
f = os.path.join('dataset-resized', 'plastic')
n = len(os.listdir(f))
k = random.sample(list(range(1,n+1)),int(.5*n))
print(k)
[254, 390, 231, 242, 334, 195, 404, 108, 49, 250, 15, 458, 428, 200, 222, 312, ...]
In [ ]:
# Cite - https://github.com/collindching/waste-sorter/blob/master/Waste%20sorter.ipynb
## helper functions ##

## splits indices for a folder into train, validation, and test indices with random sampling
    ## input: folder path
    ## output: train, valid, and test indices    
def split_indices(folder,seed1,seed2):    
    n = len(os.listdir(folder))
    print("Folder path:{}".format(folder))
    full_set = list(range(1,n+1))

    ## train indices
    random.seed(seed1)
    train = random.sample(list(range(1,n+1)),int(.5*n))

    ## temp
    remain = list(set(full_set)-set(train))

    ## separate remaining into validation and test
    random.seed(seed2)
    valid = random.sample(remain,int(.5*len(remain)))
    test = list(set(remain)-set(valid))
    print("List of indices\n {}.\n.{}\n.{}".format(train, valid, test))
    
    return(train,valid,test)

## gets file names for a particular type of trash, given indices
    ## input: waste category and indices
    ## output: file names 
def get_names(waste_type,indices):
    file_names = [waste_type+str(i)+".jpg" for i in indices]
    return(file_names)    

## moves group of source files to another folder
    ## input: list of source files and destination folder
    ## no output
def move_files(source_files,destination_folder):
    for file in source_files:
        shutil.move(file,destination_folder)

Next, we will follow the same directory convention as ImageNet:

/data
     /train
         /cardboard
         /glass
         /metal
         /paper
         /plastic
         /trash
     /valid
         /cardboard
         /glass
         /metal
         /paper
         /plastic
         /trash
    /test

Each image file name is just the material name and a number (e.g. cardboard1.jpg).
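Since each file name encodes its category, the label can be recovered directly from the name; a small sketch (the `waste_category` helper is hypothetical, introduced here for illustration):

```python
import re

def waste_category(filename):
    """Recover the waste category from a filename like 'cardboard1.jpg'
    by stripping the trailing index and the .jpg extension."""
    match = re.match(r"([a-z]+)\d+\.jpg$", filename)
    return match.group(1) if match else None

print(waste_category("cardboard1.jpg"))  # -> cardboard
print(waste_category("glass446.jpg"))    # -> glass
```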

In [ ]:
# Removing the folder if existed before..
path = Path(os.getcwd()+"/data")
subset = ['train', 'valid', 'test']
[shutil.rmtree(os.path.join(path, sub_folder)) for sub_folder in subset if os.path.exists(os.path.join(path, sub_folder))]
Out[ ]:
[None, None, None]
In [ ]:
#Cite - https://github.com/collindching/waste-sorter/blob/master/Waste%20sorter.ipynb
## paths will be train/cardboard, train/glass, etc...
subsets = ['train','valid']
waste_types = ['cardboard','glass','metal','paper','plastic','trash']

def verify_source_file(waste_type, dataset_ind):
  dataset_names = get_names(waste_type, dataset_ind)
  source_files = []
  for name in dataset_names:
    path = os.path.join(source_folder,name)
    if not os.path.exists(path):
      continue
    source_files.append(path)

  return source_files


## create destination folders for data subset and waste type
for subset in subsets:
    for waste_type in waste_types:
        folder = os.path.join('data',subset,waste_type)
        if not os.path.exists(folder):
            os.makedirs(folder)
            
if not os.path.exists(os.path.join('data','test')):
    os.makedirs(os.path.join('data','test'))
            
## move files to destination folders for each waste type
for waste_type in waste_types:
    source_folder = os.path.join('dataset-resized',waste_type)
    train_ind, valid_ind, test_ind = split_indices(source_folder,1,1)
    train_source_files = verify_source_file(waste_type, train_ind)
    train_dest = "data/train/"+waste_type
    move_files(train_source_files,train_dest)
    
    ## move source files to valid
    valid_source_files = verify_source_file(waste_type, valid_ind)
    valid_dest = "data/valid/"+waste_type
    move_files(valid_source_files,valid_dest)
    
    ## move source files to test
    test_source_files = verify_source_file(waste_type, test_ind)
    move_files(test_source_files,"data/test")
Folder path:dataset-resized/cardboard
List of indices
 [69, 292, 392, 33, 131, 61, 254, 390, ...].
.[79, 307, 389, 39, 142, 73, 269, 387, ...]
.[11, 20, 26, 27, 31, 32, 35, 36, ...]
Folder path:dataset-resized/glass
List of indices
 [69, 292, 434, 411, 392, 33, 131, 61, ...].
.[73, 307, 440, 420, 403, 40, 136, 68, ...]
.[22, 28, 30, 34, 35, 38, 39, 41, ...]
Folder path:dataset-resized/metal
List of indices
 [69, 292, 392, 33, 131, 61, 254, 390, ...].
.[79, 313, 399, 39, 145, 73, 269, 396, ...]
.[20, 26, 27, 31, 32, 34, 35, 36, ...]
Folder path:dataset-resized/paper
List of indices
 [138, 583, 65, 262, 121, 508, 461, 484, ...].
.[78, 501, 42, 192, 573, 418, 381, 397, ...]
.[516, 517, 519, 14, 527, 18, 22, 535, ...]
Folder path:dataset-resized/plastic
List of indices
 [69, 292, 434, 411, 392, 33, 131, 61, ...].
.[70, 313, 439, 417, 403, 39, 136, 64, ...]
.[6, 22, 28, 30, 32, 34, 35, 38, ...]
Folder path:dataset-resized/trash
List of indices
 [35, 17, 66, 31, 127, 116, 121, 98, ...].
.[20, 11, 48, 18, 113, 101, 46, 73, ...]
.[5, 6, 7, 8, 133, 10, 134, 12, ...]

I set the seed for both random samples to be 1 for reproducibility. Now that the data's organized, we can get to model training.

One important thing before we go any further: fast.ai uses Path extensively for handling image paths, which makes it convenient not to worry about the directory the images have to be loaded from.

In short, path is just the working directory where temporary files/models will be saved.

Path is used extensively throughout the fastai library.
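A minimal illustration of why Path is convenient: paths compose with the `/` operator and expose their components as attributes, so there is no manual string joining.

```python
from pathlib import Path

# Build a nested path with the '/' operator and inspect its parts.
data = Path("data")
train_glass = data / "train" / "glass"

print(train_glass)              # data/train/glass (on POSIX systems)
print(train_glass.name)         # -> glass
print(train_glass.parent.name)  # -> train
```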

In [ ]:
## Get the images from particular path...
path = Path(os.getcwd())/"data"
path
Out[ ]:
Path('/content/data')
In [ ]:
# Checking the counts of plastic and glass images in the train and validation sets..
print("Plastic tr {}".format(len(os.listdir(str(path)+'/train/plastic'))))
print("Plastic Vl {}".format(len(os.listdir(str(path)+'/valid/plastic'))))

print("Glass tr {}".format(len(os.listdir(str(path)+'/train/glass'))))
print("Glass Vl {}".format(len(os.listdir(str(path)+'/valid/glass'))))
Plastic tr 241
Plastic Vl 120
Glass tr 242
Glass Vl 124

It seems we are getting a reasonable split between the train and validation sets for the glass images. We will look at further improving classification later, through augmentation and other techniques.
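To check class balance across all six categories rather than just two, the counts can be tabulated in one pass (a hypothetical `subset_counts` helper, assuming the data/train and data/valid folders built above):

```python
import os

def subset_counts(root="data"):
    """Count images per (subset, waste type) pair under root."""
    counts = {}
    for subset in ("train", "valid"):
        subset_dir = os.path.join(root, subset)
        for waste_type in sorted(os.listdir(subset_dir)):
            folder = os.path.join(subset_dir, waste_type)
            counts[(subset, waste_type)] = len(os.listdir(folder))
    return counts

# e.g. subset_counts() -> {('train', 'cardboard'): ..., ('valid', 'trash'): ...}
```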

In [ ]:
#hide
tr_plastic_imgs = os.listdir(str(path)+'/train/plastic')
In [ ]:
os.listdir(str(path) + '/train/glass')
Out[ ]:
['glass151.jpg',
 'glass49.jpg',
 'glass106.jpg',
 'glass246.jpg',
 'glass446.jpg',
 'glass359.jpg',
 ...
 'glass338.jpg']
In [ ]:
# Viewing the image in the dataset
tr_glass_imgs = os.listdir(str(path)+'/train/glass')
img = Image.open(str(path)+'/train/glass/'+tr_glass_imgs[3])

In the previous version of this notebook, glass was most often misclassified as metal or plastic.

Investigating Why Glass Is Misclassified More Often than Plastic and Metal¶

In [ ]:
glass_imgs = (path/'train/glass').ls()
im = Image.open(glass_imgs[1])

# converting a sample glass image to a tensor...
first_glass = tensor(im)
print(first_glass[1:4,  4:10])
print(first_glass.shape)
tensor([[[230, 210, 186],
         [229, 209, 185],
         [229, 209, 185],
         [229, 209, 185],
         [228, 208, 184],
         [228, 208, 184]],

        [[229, 209, 185],
         [229, 209, 185],
         [229, 209, 185],
         [229, 209, 185],
         [228, 208, 184],
         [228, 208, 184]],

        [[229, 209, 185],
         [229, 209, 185],
         [228, 208, 184],
         [228, 208, 184],
         [228, 208, 184],
         [228, 208, 184]]], dtype=torch.uint8)
torch.Size([384, 512, 3])

We have converted our sample glass image to a tensor.

Next steps:

We calculate the mean of the glass training images and check whether a sample glass tensor lies closer to that mean image than to the mean plastic and metal images, since those were the most misclassified categories for our waste classifier.

In [ ]:
torch.manual_seed(1)  # seed before generating so the example is reproducible
a = torch.randn(20, 4, 4)
print(a)
mean_a = a.mean(0)

print(mean_a.shape)
tensor([[[ 1.9041, -0.6623, -0.0740, -1.8308],
         [ 0.0620, -0.2726,  1.5847, -2.0998],
         [-1.8451, -0.5164, -0.8150,  0.2383],
         [-0.0154, -1.3963, -0.2346,  0.6368]],

        [[-0.4682,  0.7713, -1.9177,  0.5771],
         [ 1.3979,  0.8343,  0.2862,  0.4237],
         [ 1.6290, -0.2072, -0.1179, -0.9172],
         [-0.6613, -1.1318,  0.1150,  2.8163]],

        [[-0.8620, -0.8489, -0.4204,  0.7048],
         [ 0.1674, -0.9637, -1.4947,  0.9189],
         [-0.3080, -3.3565,  1.1957, -0.8564],
         [ 1.2134, -0.4157,  0.0896, -0.1226]],

        [[-0.5561,  1.2022, -0.3723,  0.9290],
         [ 1.0448,  0.6112, -1.1486, -1.2055],
         [-0.5671, -0.7027, -1.4796,  1.7317],
         [-0.5427,  0.8289, -1.6866, -1.9065]],

        [[ 0.7942, -0.7050, -1.6597,  1.6873],
         [ 1.5678, -0.0256,  0.2248,  1.2104],
         [-1.0935, -0.5420, -0.2103, -2.3326],
         [ 0.3916,  0.2181, -2.0667, -0.0830]],

        [[ 1.0960,  0.4500,  0.4413,  0.5867],
         [-1.6896, -0.7949,  0.0723,  0.4627],
         [-0.6459, -1.1081, -0.6236,  0.8584],
         [-0.3111,  0.1923,  0.4706, -1.2359]],

        [[-2.0230,  0.8408,  0.7278, -0.4549],
         [-2.4891, -0.0443,  0.7275, -2.5718],
         [-1.6148, -0.9621,  0.8803,  0.2965],
         [-0.5933, -0.5933, -0.6970, -0.8499]],

        [[-0.4422,  0.0171, -2.3430,  0.7095],
         [-1.2889,  1.5792,  1.1738,  0.6752],
         [-1.3570,  0.1963, -0.9830,  1.6627],
         [ 0.6076,  0.0658,  0.3313, -0.1094]],

        [[-0.0680, -0.5092,  1.9973, -0.2931],
         [-0.8382, -1.0782,  0.4462, -0.0646],
         [-0.5631,  1.4206,  1.7943,  1.8898],
         [ 0.3165,  0.2293,  0.7299, -1.4330]],

        [[ 1.2045,  0.6211, -1.1297,  0.7161],
         [ 2.2381,  0.3150,  1.0000,  2.0494],
         [-0.1351,  1.8014,  0.5025, -0.0968],
         [-1.1127, -0.5012, -0.8887,  0.2562]],

        [[ 0.2549, -0.7717,  0.4916,  1.4921],
         [-0.7116, -0.1897, -1.6547, -0.8994],
         [ 0.1157,  0.8570,  0.7750, -1.0252],
         [ 0.4516,  1.6063, -0.4316, -1.5365]],

        [[ 0.0418, -0.4215, -0.0071,  0.3110],
         [ 0.2383, -1.7163, -0.3315,  2.7632],
         [ 0.6609,  0.0607, -0.9238, -0.9787],
         [ 0.6828, -1.0436, -0.6931,  0.1954]],

        [[-0.3110,  0.2085,  0.4414, -0.5819],
         [ 1.5436, -1.5502, -0.8435, -0.7709],
         [-0.7234, -0.8946, -0.8756, -0.9477],
         [ 0.0690, -2.8165,  0.9283, -0.6112]],

        [[-0.1419, -0.1804,  0.4074, -0.2060],
         [ 1.1673, -0.3224,  0.1169, -1.5055],
         [ 0.4926,  1.4967, -0.1806,  1.6904],
         [-0.3416,  0.1862,  0.4881,  0.3701]],

        [[-0.2108,  0.2191, -1.1873, -1.0986],
         [ 1.4215, -0.3173,  0.0849, -0.7219],
         [-0.3716, -0.9612, -0.3766, -0.9942],
         [ 0.6933,  0.8511,  1.3226, -0.3613]],

        [[ 0.1691, -1.3188,  0.9293, -1.1307],
         [ 0.9253, -1.1454, -0.6973, -0.0113],
         [-1.9940, -0.9999, -1.1215, -0.8117],
         [-0.6660,  0.9826,  1.1000,  1.2608]],

        [[ 1.2468, -0.1739,  0.6360,  0.4991],
         [-0.0972,  1.2482,  0.7075,  2.0349],
         [ 1.3595, -0.5676,  1.0609, -0.4931],
         [ 0.8091,  0.3110, -0.4240, -0.5833]],

        [[ 1.5777, -1.2872, -0.5809, -0.3320],
         [ 1.2174, -2.0281, -0.2087, -0.7707],
         [ 1.2203,  0.3608, -2.1707,  0.9702],
         [ 1.6480, -0.1294,  0.3269,  1.0901]],

        [[-0.0916,  0.2605,  0.1513,  0.3333],
         [ 0.8916, -1.0663, -0.6575, -0.8669],
         [-0.4498, -0.8968,  1.9549,  0.6996],
         [ 0.6076,  0.1456,  1.4862,  0.0192]],

        [[-0.2001,  0.9258,  0.2802, -0.7186],
         [ 0.1723, -1.8113,  0.8975, -1.8717],
         [-0.7931, -0.5257, -0.3883, -0.5874],
         [ 0.6360, -0.6921, -0.7885,  1.1872]]])
torch.Size([4, 4])

The code above stacks 20 random 4×4 matrices and takes the mean along the first axis: the 20 matrices are averaged element-wise, producing a single mean 4×4 matrix.

We will use the same process to stack the 384×512×3 image tensors for each folder. For example, if the training set has 240 images of shape 384×512×3, stacking them gives a tensor of shape (240, 384, 512, 3). We can then average the pixels over the first axis to form one single mean image, which can be compared against any new 384×512×3 image. This is the baseline model; it probably won't match the state-of-the-art transfer-learning models used later in the notebook.

In [ ]:
glass_tensors = [tensor(Image.open(g_img)) for g_img in glass_imgs]
print(len(glass_tensors))
plastic_imgs = (path/'train/plastic').ls()
plastic_tensors = [tensor(Image.open(g_img)) for g_img in plastic_imgs]
print(len(plastic_tensors))
metal_imgs = (path/'train/metal').ls()
metal_tensors = [tensor(Image.open(g_img)) for g_img in metal_imgs]
print(len(metal_tensors))

# stacking up all the images of glass, plastic and metal...
glass_stack = torch.stack(glass_tensors).float()/255
print(glass_stack.shape, glass_stack.ndim)

plastic_stack = torch.stack(plastic_tensors).float()/255
print(plastic_stack.shape, plastic_stack.ndim)

metal_stack = torch.stack(metal_tensors).float()/255
print(metal_stack.shape, metal_stack.ndim)

mean_glass_tensor = glass_stack.mean(0)
print("mean shape of glass tensor ==>", mean_glass_tensor.shape)

mean_metal_tensor = metal_stack.mean(0)
show_image(mean_metal_tensor)

show_image(mean_glass_tensor)
242
241
205
torch.Size([242, 384, 512, 3]) 4
torch.Size([241, 384, 512, 3]) 4
torch.Size([205, 384, 512, 3]) 4
mean shape of glass tensor ==> torch.Size([384, 512, 3])
Out[ ]:
<matplotlib.axes._subplots.AxesSubplot at 0x7ff696e7e0d0>
In [ ]:
mean_plastic_tensor = plastic_stack.mean(0)
show_image(mean_plastic_tensor)
Out[ ]:
<matplotlib.axes._subplots.AxesSubplot at 0x7ff699b4d690>
In [ ]:
# Example 1: To check which mean image is the closest to a sample glass image's pixels....

diff1 = (glass_stack[1] - mean_glass_tensor).abs().mean()
print(diff1)


diff2 = (glass_stack[1] - mean_plastic_tensor).abs().mean()
print(diff2)


diff3 = (glass_stack[1] - mean_metal_tensor).abs().mean()
print(diff3)

# Clearly the lowest distance is to the glass tensor; plastic and metal follow in the ranking..
# Hence this baseline model classifies this as a glass tensor
tensor(0.0943)
tensor(0.0994)
tensor(0.1026)
In [ ]:
# takes the mean across x and y axis
def glass_distance(a,b): return (a-b).abs().mean((-1, -2))
def is_glass(x): return glass_distance(x,mean_glass_tensor[:, :, 0]) < glass_distance(x,mean_plastic_tensor[:, :, 0])
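A quick numpy sketch (illustrative toy data, not the notebook's fastai tensors) of why the distance takes the mean over the last two axes: broadcasting lets the same function handle a single image or a whole stack of images at once.

```python
import numpy as np

# Toy stand-ins for the mean glass image and a stack of candidate images
mean_img = np.full((4, 4), 0.5)
stack = np.stack([np.full((4, 4), 0.6), np.full((4, 4), 0.9)])

def glass_distance(a, b):
    # mean over the last two axes (height and width) leaves one
    # distance per image, whether `a` is (H, W) or (N, H, W)
    return np.abs(a - b).mean(axis=(-1, -2))

print(glass_distance(stack, mean_img))     # one distance per stacked image
print(glass_distance(stack[0], mean_img))  # also works for a single image
```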
In [ ]:
is_glass(plastic_stack[-1][:, :, 0]), is_glass(plastic_stack[-1][:, :, 0]).float()
Out[ ]:
(tensor(False), tensor(0.))
In [ ]:
valid_path = (path/'valid/glass').ls()
print(len(valid_path))
valid_path = (path/'valid/plastic').ls()
print(len(valid_path))

valid_path = (path/'train/glass').ls()
print(len(valid_path))
valid_path = (path/'train/plastic').ls()
print(len(valid_path))

#######

valid_path = (path/'valid/plastic').ls()
plastic_imgs = [Image.open(img) for img in valid_path]
plastic_valid_tens = torch.stack([tensor(Image.open(img)) for img in valid_path])
plastic_valid_tens = plastic_valid_tens.float()/255
valid_plastic_tens = plastic_valid_tens[:, :, :, 0]

preds = is_glass(valid_plastic_tens)
trues = [p for p in preds if p]
falses = [p for p in preds if not p]
print(len(trues), len(falses))
124
120
242
241
68 52

Accuracy of our initial baseline model, which compares average image pixels against a new glass or plastic image.

In [ ]:
valid_path = (path/'valid/glass').ls()
glass_imgs = [Image.open(img) for img in valid_path]
#print(valid_path)
glass_valid_tens = torch.stack([tensor(Image.open(img)) for img in valid_path])
glass_valid_tensors = glass_valid_tens.float()/255

valid_glass_tens = glass_valid_tensors[:, :, :, 0]
accuracy_glass = is_glass(valid_glass_tens).float().mean()
accuracy_plastic = (1 - is_glass(valid_plastic_tens).float()).mean()

# simple accuracy formula to compute...
print(accuracy_glass,accuracy_plastic,(accuracy_plastic+accuracy_glass)/2)

This simplistic baseline model shows why a classifier can get confused between metal and glass images: the distances between these tensors are very close to each other. In simpler terms, they are the most correlated categories.

Let's now put our model into practice, first applying augmented transforms and then feeding the data into the model. Here are some details about the function.

In [ ]:
doc(aug_transforms)

aug_transforms[source]

aug_transforms(mult=1.0, do_flip=True, flip_vert=False, max_rotate=10.0, min_zoom=1.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75, xtra_tfms=None, size=None, mode='bilinear', pad_mode='reflection', align_corners=True, batch=False, min_scale=1.0)

Utility func to easily create a list of flip, rotate, zoom, warp, lighting transforms.

Type Default
mult float 1.0
do_flip bool True
flip_vert bool False
max_rotate float 10.0
min_zoom float 1.0
max_zoom float 1.1
max_lighting float 0.2
max_warp float 0.2
p_affine float 0.75
p_lighting float 0.75
xtra_tfms NoneType None
size NoneType None
mode str bilinear
pad_mode str reflection
align_corners bool True
batch bool False
min_scale float 1.0

Show in docs

Next Important class is ImageDataLoaders.

ImageDataLoaders - Wrapper around DataLoaders with factory methods for computer vision.

In simpler terms, ImageDataLoaders has several helper functions which can easily load the data as a DataLoaders object.

According to the fast.ai ImageDataLoaders docs:

This class should not be used directly, one of the factory methods should be preferred instead. All those factory methods accept as arguments:

  1. item_tfms: one or several transforms applied to the items before batching them
  2. batch_tfms: one or several transforms applied to the batches once they are formed
  3. bs: the batch size
  4. val_bs: the batch size for the validation DataLoader (defaults to bs)
  5. shuffle_train: whether or not to shuffle the training DataLoader
  6. device: the PyTorch device to use (defaults to default_device())
In [ ]:
doc(ImageDataLoaders)

class ImageDataLoaders[source]

ImageDataLoaders(*loaders, path='.', device=None) :: DataLoaders

Basic wrapper around several DataLoaders with factory methods for computer vision problems

Type Default
loaders
path str .
device NoneType None

Show in docs

In [ ]:
tfms = aug_transforms(do_flip=True,flip_vert=True)
data= ImageDataLoaders.from_folder(path, train = "train", valid = "valid", 
                                    batch_tfms=[*tfms, Normalize.from_stats(*imagenet_stats)],bs = 16)

Important note from a transformations perspective

  1. Normalize.from_stats(*imagenet_stats): here we say that each image should be normalized with respect to imagenet_stats (ImageNet's statistics across the 3 channels); this gives us the mean and standard deviation tensors for the 3 dimensions.

The idea is to bring each pixel value close to the center so that the data dimensions are of approximately the same scale. We'd like each feature to have a similar range so that our gradients don't go out of control (and so that we only need one global learning rate multiplier).

For example, in the case of RGB channels, we perform this process for each channel with a simple formula: first subtract the per-channel mean, then divide by the per-channel standard deviation. Here is a demonstration of how to perform normalization in Python using the numpy library.

X -= np.mean(X, axis = 0)
X /= np.std(X, axis = 0)

prepro1.png

image source - https://cs231n.github.io/neural-networks-2/
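To make the normalization step concrete, here is a minimal numpy sketch (synthetic data, not the ImageNet statistics) of per-channel zero-centering and scaling, which is what Normalize.from_stats does with the supplied mean and std:

```python
import numpy as np

# Synthetic batch of 8 RGB images with values in [0, 255] (illustrative only)
rng = np.random.default_rng(0)
X = rng.uniform(0, 255, size=(8, 32, 32, 3))

# Per-channel statistics, computed over batch, height and width
mean = X.mean(axis=(0, 1, 2))  # shape (3,)
std = X.std(axis=(0, 1, 2))    # shape (3,)

# Zero-center, then scale: each channel now has mean ~0 and std ~1
X_norm = (X - mean) / std
print(X_norm.mean(axis=(0, 1, 2)), X_norm.std(axis=(0, 1, 2)))
```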

The batch size bs is how many images you'll train at a time. Similarly, we can specify the validation batch size, which defaults to the bs we provided. A smaller bs will work on machines with less memory.

You can use the aug_transforms() function to augment your data. I'll compare the results from flipping images horizontally and vertically.

In [ ]:
print(data.vocab)
['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

The show_batch function in fast.ai

Sample images can be viewed with the data.show_batch method.

It takes a figure-size tuple as an argument.

This function displays batches of images so we can quickly glance at the data we are working with.

One of its most commonly used arguments:

  1. rows: to specify the number of rows we want to display in our batch of images
In [ ]:
type(data)
Out[ ]:
fastai.data.core.DataLoaders
In [ ]:
# Applying show_batch function to our data loader object 'data' and visualize some of the images
data.show_batch(figsize=(10,8))

Building our Waste Classifier Model¶

If you run the program with CUDA_LAUNCH_BLOCKING=1, this will help get a more exact stack trace

In [ ]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
In [ ]:
# just one line to create our state of the art transfer learning model
learn = cnn_learner(data,models.resnet34,metrics=error_rate)

Cite - Notebook link

What is resnet34?¶

A residual neural network is a convolutional neural network (CNN) with lots of layers. In particular, resnet34 is a CNN with 34 layers that's been pretrained on the ImageNet database. A pretrained CNN will perform better on new image classification tasks because it has already learned some visual features and can transfer that knowledge over (hence transfer learning).

Since they're capable of describing more complexity, deep neural networks should theoretically perform better than shallow networks on training data. In reality, though, deep neural networks tend to perform empirically worse than shallow ones.

Resnets were created to circumvent this problem using a trick called shortcut connections. If some nodes in a layer have suboptimal values, you can adjust the weights and biases; if a node is already optimal (its residual is 0), why not leave it alone? Adjustments are only made to nodes on an as-needed basis (when there are non-zero residuals).

When adjustments aren't needed, shortcut connections apply the identity function to pass information straight to subsequent layers. This effectively shortens the network when possible, allowing resnets to have deep architectures while behaving more like shallow networks. The 34 in resnet34 just refers to the number of layers.

Here is an interesting link on the ResNet architecture

https://blog.roboflow.com/custom-resnet34-classification-model/
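As a toy illustration of the shortcut-connection idea (a numpy sketch with made-up weights, not fastai's actual resnet implementation): when a block's layers contribute nothing, the identity shortcut still passes the input through unchanged.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def residual_block(x, W1, W2):
    # the layers learn a residual F(x); the shortcut adds x back
    return relu(W2 @ relu(W1 @ x) + x)

x = np.array([1.0, 2.0, 3.0, 4.0])

# Degenerate case: zero weights contribute nothing, yet the block
# still outputs the (non-negative) input thanks to the shortcut
Z = np.zeros((4, 4))
print(residual_block(x, Z, Z))  # identity survives: [1. 2. 3. 4.]

# With small non-zero weights, the block adjusts x only as needed
rng = np.random.default_rng(0)
print(residual_block(x, rng.normal(size=(4, 4)) * 0.1, rng.normal(size=(4, 4)) * 0.1))
```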

In [ ]:
doc(learn.model)

doc output for learn.model, an nn.Sequential (the header showing the model's full repr is truncated here):

A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an OrderedDict of modules can be passed in. The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then "chains" outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

The value a Sequential provides over manually calling a sequence of modules is that it allows treating the whole container as a single module, such that performing a transformation on the Sequential applies to each of the modules it stores (which are each a registered submodule of the Sequential).

What's the difference between a Sequential and a :class:torch.nn.ModuleList? A ModuleList is exactly what it sounds like--a list for storing Module s! On the other hand, the layers in a Sequential are connected in a cascading way.

Example::

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

Viewing Transformations in action for one image¶

In [ ]:
# taking one image sample
tr_glass_imgs = os.listdir(str(path)+'/train/glass')[0]
ig=PILImage(PILImage.create(str(path)+'/train/glass/'+tr_glass_imgs).resize((600,400)))
In [ ]:
type(array(ig))
Out[ ]:
numpy.ndarray
In [ ]:
# Small example - Consider a tensor 
sample_tensor = torch.randn(5,4)
# next step is to use permute method
after_permute_operation = sample_tensor.permute(1, 0) 
# permute(1, 0) -> this is essentially swapping up the two axis
print(after_permute_operation.shape)
assert after_permute_operation.shape[0] == 4
torch.Size([4, 5])
In [ ]:
array(ig).shape
Out[ ]:
(400, 600, 3)
In [ ]:
# Using TensorImage class to convert the numpy.ndarray to tensors
# permute method - Used to reorder or reorganize the dimensions of an image
timg = TensorImage(array(ig)).permute(2,0,1).float()/255.
# The function below expands a single image tensor into a batch of size bs
def _batch_ex(bs): return TensorImage(timg[None].expand(bs, *timg.shape).clone())
In [ ]:
# A simple and a short glance into fastai classes and functions related to transformations..
tfms = aug_transforms(do_flip=True)
for i in tfms:
  # tfms holds two transform objects when we pass do_flip=True (Flip and Brightness)
  print("Class ===>>>>", i, i.__getattribute__)
  # We can instantiate each of these transformation classes and then pass our tensor-form images, as shown in the code below...
Class ===>>>> Flip -- {'size': None, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 0.5}:
encodes: (TensorImage,object) -> encodes
(TensorMask,object) -> encodes
(TensorBBox,object) -> encodes
(TensorPoint,object) -> encodes
decodes:  <method-wrapper '__getattribute__' of Flip object at 0x7ff627724d10>
Class ===>>>> Brightness -- {'max_lighting': 0.2, 'p': 1.0, 'draw': None, 'batch': False}:
encodes: (TensorImage,object) -> encodes
decodes:  <method-wrapper '__getattribute__' of Brightness object at 0x7ff627724850>
In [ ]:
# Performing image transformations through function..
y = _batch_ex(2) 
for t in tfms: 
  # split_idx=0 means we are passing a training image; split_idx=1 refers to validation in fastai
  y = t(y, split_idx=0)
  
_,axs = plt.subplots(1,2, figsize=(10,8))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax)

Choosing a Learning Rate¶

fastai uses the one-cycle policy, which means the learning rate starts at a low value, climbs to a large value, and then settles at a value lower than the initial LR. This has shown impressive results because we avoid ending up with either a consistently slow or a very high LR.

To choose the best LR, look at the point where the loss curve is steepest. Please note that the steepest point is not the point of minimum loss; it is the point where the loss is dropping fastest.

lr_vs_clr_resnet56.webp

image source - https://iconof.com/1cycle-learning-rate-policy/
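A rough numpy sketch of the one-cycle shape (an illustrative cosine warmup/anneal, not fastai's exact internal schedule; the parameter names pct_start, div and div_final are assumptions mirroring fastai's conventions):

```python
import numpy as np

def one_cycle(lr_max, n_steps, pct_start=0.25, div=25.0, div_final=1e4):
    """Cosine ramp from lr_max/div up to lr_max, then anneal to lr_max/div_final."""
    warm = int(n_steps * pct_start)
    t_up = np.arange(warm) / warm
    up = lr_max / div + (lr_max - lr_max / div) * (1 - np.cos(np.pi * t_up)) / 2
    t_down = np.arange(n_steps - warm) / (n_steps - warm - 1)
    down = lr_max / div_final + (lr_max - lr_max / div_final) * (1 + np.cos(np.pi * t_down)) / 2
    return np.concatenate([up, down])

sched = one_cycle(5e-3, 100)
print(sched[0], sched.max(), sched[-1])  # starts low, peaks at lr_max, ends lowest
```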

In [ ]:
# start_lr = starting learning rate
# end_lr = maximum learning rate at which we want the model to stop finding LR.
learn.lr_find(start_lr=1e-6,end_lr=1e1)
/usr/lib/python3.7/json/encoder.py:257: UserWarning: date_default is deprecated since jupyter_client 7.0.0. Use jupyter_client.jsonutil.json_default.
  return _iterencode(o, 0)
Out[ ]:
SuggestedLRs(valley=0.0012022644514217973)
In [ ]:
lr_min, lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))
In [ ]:
# here is a simple way to sample a random number from an interval, which can help when guessing an optimal LR between the suggested bounds
import numpy as np
np.random.uniform(1e-4, 1e-3)
Out[ ]:
0.0007015538560113359
In [ ]:
# fit_one_cycle - trains the model for the given number of epochs using the one-cycle policy
# One epoch = one complete cycle of forward propagation/back propagation over the training data
learn.fit_one_cycle(20, lr_max=5e-03)
epoch train_loss valid_loss error_rate time
0 1.626455 0.673134 0.244833 01:40
1 0.984484 0.516873 0.189189 01:40
2 0.834137 0.616733 0.178060 01:40
3 0.862364 0.832932 0.243243 01:40
4 0.840459 0.745179 0.225755 01:40
5 0.778823 0.542830 0.181240 01:40
6 0.682324 0.816168 0.244833 01:40
7 0.638330 0.502149 0.155803 01:40
8 0.557140 0.393592 0.117647 01:40
9 0.572882 0.529694 0.165342 01:40
10 0.462756 0.340510 0.109698 01:40
11 0.385636 0.312294 0.112878 01:40
12 0.369283 0.339354 0.114467 01:40
13 0.298273 0.284633 0.084261 01:40
14 0.274431 0.237938 0.076312 01:40
15 0.236636 0.229560 0.071542 01:40
16 0.165803 0.201756 0.068362 01:40
17 0.143642 0.204025 0.069952 01:40
18 0.144287 0.192077 0.066773 01:40
19 0.185571 0.196076 0.071542 01:39

The model ran for 20 epochs, reaching a minimum error rate of about 0.067, better than the previous model's roughly 0.08. Hence we were able to reduce the error and increase the accuracy, as we will see later.

Visualising most incorrect images¶

In [ ]:
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
In [ ]:
plt.figure(figsize=(10, 8))
interp.plot_top_losses(4, nrows=2)
plt.show()
<Figure size 720x576 with 0 Axes>
In [ ]:
interp.plot_top_losses(10, figsize=(15,11))

The images here are the ones remaining after removing some over-exposed images, which was not done in collindching's previous notebook. This made an impact: the model is now less confused, and hence the loss decreased.

In [ ]:
# Here's what the documentation of the plot_top_losses function looks like...
doc(interp.plot_top_losses)

Interpretation.plot_top_losses[source]

Interpretation.plot_top_losses(k, largest=True, **kwargs)

Show k largest(/smallest) preds and losses. k may be int, list, or range of desired results.

Type Default
k
largest bool True
kwargs

Show in docs


In the previous version of this notebook by collindching, the model often confused plastic for glass and metal for glass. The list of most confused images is below. Later, we will see whether we were able to reduce the overall misclassification error for these categories.

In [ ]:
# checking where our model got most confused in classifying...
interp.most_confused(min_val=2)
Out[ ]:
[('glass', 'metal', 7),
 ('cardboard', 'paper', 5),
 ('glass', 'plastic', 5),
 ('plastic', 'metal', 4),
 ('plastic', 'paper', 4),
 ('trash', 'paper', 4),
 ('metal', 'paper', 2),
 ('metal', 'plastic', 2),
 ('plastic', 'trash', 2)]

4. Predicting the test data¶

To see how this model really performs, we need to make predictions on the test data. First, I'll make predictions on the test data using the learner.get_preds() method.

Note: learner.predict() only predicts on a single image, while learner.get_preds() predicts on a set of images. I highly recommend reading the documentation to learn more about predict() and get_preds().

In [ ]:
doc(learn.predict)

Learner.predict[source]

Learner.predict(item, rm_type_tfms=None, with_input=False)

Prediction on item, fully decoded, loss function decoded and probabilities

Type Default
item
rm_type_tfms NoneType None
with_input bool False

Show in docs


Loading Test DataLoader for Prediction¶

In the cells above, we defined DataLoaders for our train and validation sets.

However, for our model to predict on new image data, we have to load that data into a DataLoader object.

We can simply pass the test folder to the get_image_files function, which traverses the test folder and returns all the images in it.

get_preds is the function that gives us the probability of each image belonging to each class.

In [ ]:
#get predictions..
test_dl = data.test_dl(get_image_files(os.path.join(path, 'test')))
preds = learn.get_preds(dl=test_dl)
In [ ]:
print(preds[0].shape)
preds[0]
torch.Size([631, 6])
Out[ ]:
TensorBase([[1.7182e-04, 7.1475e-05, 1.5689e-03, 2.3538e-05, 9.9789e-01, 2.7502e-04],
        [8.1237e-08, 5.9949e-08, 1.6182e-08, 9.9999e-01, 1.9660e-07, 1.1921e-05],
        [1.3310e-06, 6.2397e-03, 9.9346e-01, 4.1719e-07, 2.6252e-04, 3.6827e-05],
        ...,
        [5.1307e-06, 5.3394e-05, 9.9495e-01, 1.8174e-04, 9.8943e-05, 4.7127e-03],
        [1.7216e-07, 1.0502e-05, 7.9543e-04, 7.6339e-06, 9.9208e-05, 9.9909e-01],
        [8.3721e-06, 5.6436e-08, 6.0167e-08, 9.9991e-01, 1.5351e-05, 6.4679e-05]])

Converting probabilities to class names¶

Simple approach - for each image, take the class with the maximum probability value among the 6 classes.

That is, choose the maximum value along axis 1, i.e. the columns.

Each row holds the probabilities for one sample image.
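The conversion can be sketched with a toy probability matrix (made-up numbers in the same 6-class layout as preds[0]):

```python
import numpy as np

classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Toy predictions: 3 images x 6 class probabilities (each row sums to 1)
probs = np.array([
    [0.05, 0.10, 0.05, 0.05, 0.70, 0.05],
    [0.02, 0.02, 0.02, 0.90, 0.02, 0.02],
    [0.01, 0.06, 0.88, 0.01, 0.02, 0.02],
])

max_idxs = probs.argmax(axis=1)        # column index of the max per row
yhat = [classes[i] for i in max_idxs]  # map indices to class names
print(yhat)  # ['plastic', 'paper', 'metal']
```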

In [ ]:
## saves the index (0 to 5) of most likely (max) predicted class for each image
max_idxs = np.asarray(np.argmax(preds[0],axis=1))
In [ ]:
# Append the class labels with index found from maximum probability class..
max_idxs = np.asarray(np.argmax(preds[0],axis=1))
classes = data.vocab
print(classes)
yhat = []
for max_idx in max_idxs:
    yhat.append(classes[max_idx])
['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
In [ ]:
# A quick check on what get_image_files does...
l = get_image_files(os.path.join(path, 'test'))
In [ ]:
from PIL import Image
In [ ]:
Image.open(l[0])
Out[ ]:
In [ ]:
# Validate against the test set...
learn.validate(dl=test_dl)
Out[ ]:
(#2) [None,None]
In [ ]:
y = []

## convert POSIX paths to string first
for label_path in test_dl.items:
    y.append(str(label_path))
    
# then extract waste type from file path
pattern = re.compile("([a-z]+)[0-9]+")
for i in range(len(y)):
    y[i] = pattern.search(y[i]).group(1)
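The regex above pulls the leading run of lowercase letters off each filename to recover the actual label; a quick standalone check:

```python
import re

# Same pattern as above: waste type = the letters before the trailing number
pattern = re.compile("([a-z]+)[0-9]+")

names = ["glass371.jpg", "plastic12.jpg", "cardboard7.jpg"]
labels = [pattern.search(n).group(1) for n in names]
print(labels)  # ['glass', 'plastic', 'cardboard']
```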

Creating a Dataframe of the results observed¶

In [ ]:
## predicted values
print(yhat[0:30])
## actual values
print(y[0:30])
df_dict = {'actual': y, 'predicted': yhat}
pd.DataFrame(df_dict)
['plastic', 'paper', 'metal', 'paper', 'cardboard', 'cardboard', 'paper', 'trash', 'cardboard', 'glass', 'plastic', 'paper', 'paper', 'cardboard', 'glass', 'plastic', 'glass', 'plastic', 'plastic', 'glass', 'paper', 'glass', 'paper', 'trash', 'cardboard', 'paper', 'plastic', 'metal', 'paper', 'paper']
['plastic', 'paper', 'metal', 'paper', 'cardboard', 'cardboard', 'paper', 'trash', 'cardboard', 'glass', 'plastic', 'plastic', 'paper', 'cardboard', 'glass', 'plastic', 'glass', 'plastic', 'plastic', 'glass', 'paper', 'glass', 'paper', 'trash', 'cardboard', 'paper', 'plastic', 'metal', 'paper', 'paper']
Out[ ]:
actual predicted
0 plastic plastic
1 paper paper
2 metal metal
3 paper paper
4 cardboard cardboard
... ... ...
626 plastic plastic
627 glass glass
628 metal metal
629 trash trash
630 paper paper

631 rows × 2 columns

It looks like the first five predictions match up!

How did we end up doing? Again, we can use a confusion matrix to find out.

Test confusion matrix¶

In [ ]:
cm = confusion_matrix(y,yhat)
print(cm)
[[ 96   0   2   2   1   0]
 [  0 111   5   0   6   0]
 [  0   5  98   0   0   0]
 [  1   0   0 145   1   2]
 [  0   2   0   1 117   1]
 [  0   0   0   5   3  27]]
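Since each row of the confusion matrix is an actual class, dividing the diagonal by the row sums gives per-class recall, which makes the weak spots easier to read than the raw counts. A sketch using the matrix printed above (assuming the classes are in the dataset's usual alphabetical order, which is what `waste_types` would hold):

```python
import numpy as np

# Test-set confusion matrix from above: rows = actual, columns = predicted.
cm = np.array([[ 96,   0,   2,   2,   1,   0],
               [  0, 111,   5,   0,   6,   0],
               [  0,   5,  98,   0,   0,   0],
               [  1,   0,   0, 145,   1,   2],
               [  0,   2,   0,   1, 117,   1],
               [  0,   0,   0,   5,   3,  27]])
waste_types = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

# Recall per class: correct predictions divided by the number of actual items.
recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(waste_types, recall):
    print(f"{name}: {r:.3f}")
```

With this ordering, the last class (trash, 27/35 ≈ 0.771) stands out as the hardest, consistent with the misclassifications we set out to reduce.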

Visualising Confusion Matrix of Test Set¶

In [ ]:
df_cm = pd.DataFrame(cm,waste_types,waste_types)

plt.figure(figsize=(10,8))
sns.heatmap(df_cm,annot=True,fmt="d",cmap="YlGnBu")
Out[ ]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f084460cdd0>
In [ ]:
correct = 0
for r in range(len(cm)):
    for c in range(len(cm)):
        if (r==c):
            correct += cm[r,c]
In [ ]:
accuracy = correct/sum(sum(cm))
accuracy
Out[ ]:
0.9413629160063391
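The nested loop above just sums the diagonal of the confusion matrix, so the same accuracy can be computed in one line with NumPy (a sketch using the matrix printed earlier):

```python
import numpy as np

cm = np.array([[ 96,   0,   2,   2,   1,   0],
               [  0, 111,   5,   0,   6,   0],
               [  0,   5,  98,   0,   0,   0],
               [  1,   0,   0, 145,   1,   2],
               [  0,   2,   0,   1, 117,   1],
               [  0,   0,   0,   5,   3,  27]])

# trace = sum of the diagonal = number of correct predictions
accuracy = np.trace(cm) / cm.sum()
print(accuracy)  # → 0.9413629160063391
```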

We ended up achieving 94.1% test accuracy, which is slightly better than the previous notebook. We were also able to act on the next steps mentioned in collindching's notebook and reduce the misclassification error.

Comparing our confusion matrix with previous confusion matrix¶

Collindching's version of the confusion matrix vs. my version:

Collindching's version: collindching cm.png

My version: my version.png

Google Colaboratory link for the code and experiments:

https://github.com/amay1212/Waste-Sorting/blob/master/Waste_Sorter%20Extended.ipynb

In [ ]:
# hide
## delete everything when you're done to save space
# shutil.rmtree("data")
# shutil.rmtree('dataset-resized')

Further enhancements and Research area¶

  1. Improve accuracy even further on the misclassified classes (trash in particular).

  2. Test this model with BackRep, a background-replacement data-augmentation approach, described in the paper below:

    https://www.mdpi.com/2313-433X/7/8/144/pdf#:~:text=BackRep%20consists%20of%20a%20data,solid%20waste%20is%20usually%20littered.

  3. See whether the model can distinguish more clearly between dry and wet waste in particular.
